import argparse
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# -------------------------------------------------------------------
# Filename parsing utilities
# -------------------------------------------------------------------

def parse_filename(csv_path: Path) -> Tuple[str, bool, List[float]]:
    """Extract (dataset, estimate_pi_b flag, stratum ratio) from a result CSV name.

    Recognised names look like
      rel_rmse__dataset=<dataset>__estimate_pi_b=<true|false>__...__stratum_ratio=a-b-....csv
    i.e. ``key=value`` fields joined by double underscores. Fields with other
    keys are ignored, and when a key repeats the last occurrence wins.

    Returns:
        ``(dataset, estimate_flag, ratio_values)``. Missing fields default to
        ``"unknown"``, ``False`` and ``[]``; the full default triple is returned
        when the name lacks the ``rel_rmse__`` prefix or the ``.csv`` suffix.
        The flag is true only for the value "true" (case-insensitive). A ratio
        with any non-numeric component yields ``[]``; empty components (e.g.
        from ``1--2``) are skipped.
    """
    prefix = "rel_rmse__"
    suffix = ".csv"
    filename = csv_path.name
    if not filename.startswith(prefix) or not filename.endswith(suffix):
        return "unknown", False, []

    dataset = "unknown"
    estimate_flag = False
    ratio_values: List[float] = []

    stem = filename[len(prefix) : len(filename) - len(suffix)]
    for field in stem.split("__"):
        # Everything after the first "=" is the value (values may contain "=").
        value = field.partition("=")[2]
        if field.startswith("dataset="):
            dataset = value
        elif field.startswith("estimate_pi_b="):
            estimate_flag = value.lower() == "true"
        elif field.startswith("stratum_ratio="):
            try:
                ratio_values = [float(tok) for tok in value.split("-") if tok]
            except ValueError:
                ratio_values = []

    return dataset, estimate_flag, ratio_values


def load_rel_rmse(csv_path: Path) -> Dict[str, float]:
    """Load rel-RMSE values from one CSV into ``{method: rel_rmse}``.

    The CSV is read with ``pandas.read_csv`` and must provide ``method`` and
    ``rel_rmse`` columns. Method names are converted to ``str`` and values to
    Python ``float``. If a method appears on several rows, the last row wins.

    Raises:
        KeyError: if either required column is missing.
    """
    df = pd.read_csv(csv_path)
    # Convert column-wise rather than with ``iterrows``: iterrows builds one
    # Series per row and upcasts mixed int/float rows to float64, so an integer
    # method id 1 would silently become the key "1.0". Column-wise conversion
    # keeps each column's own dtype and is also much faster.
    methods = df["method"].astype(str).tolist()
    values = df["rel_rmse"].astype(float).tolist()
    return dict(zip(methods, values))


def ratio_to_x_value(stratum_ratio: List[float]) -> float:
    """Map a stratum ratio to the plot's x coordinate, n1/n2.

    Only the first two entries are used; any further entries are ignored.
    Returns NaN when fewer than two entries are given, and +inf when the
    second entry is zero.
    """
    if len(stratum_ratio) >= 2:
        n1, n2 = stratum_ratio[0], stratum_ratio[1]
        return float("inf") if n2 == 0 else n1 / n2
    return float("nan")


def title_from_dataset_id(dataset_id: str) -> str:
    """Return the display title for a dataset id.

    Known ids get their hand-picked capitalisation; any other id falls back
    to ``str.title()``.
    """
    pretty_names = {
        "letter": "Letter",
        "optdigits": "OptDigits",
        "pendigits": "PenDigits",
        "sat": "SatImage",
    }
    if dataset_id in pretty_names:
        return pretty_names[dataset_id]
    return dataset_id.title()


# -------------------------------------------------------------------
# Aggregation
# -------------------------------------------------------------------

def aggregate_series_by_dataset(
    csv_files: List[Path], estimate_flag: bool
) -> Dict[str, Dict[str, List[Tuple[float, float]]]]:
    """Group (x, rel_rmse) points by dataset and method for one estimate flag.

    Files whose name does not parse to a dataset, or whose ``estimate_pi_b``
    flag differs from ``estimate_flag``, are skipped. Each remaining file adds
    one point per method, with x = n1/n2 taken from its stratum ratio.

    Returns:
        ``{dataset_id: {method_key: [(x, rel_rmse), ...]}}`` with every series
        sorted by ascending x and NaN x values placed last.
    """
    grouped: Dict[str, Dict[str, List[Tuple[float, float]]]] = {}

    for path in csv_files:
        ds_name, file_flag, ratio = parse_filename(path)
        if ds_name == "unknown":
            continue
        if file_flag != estimate_flag:
            continue

        x = ratio_to_x_value(ratio)
        per_method = grouped.setdefault(ds_name, {})
        for method, value in load_rel_rmse(path).items():
            per_method.setdefault(method, []).append((x, value))

    def _nan_last(point: Tuple[float, float]):
        # NaN compares unordered, so push it to the end explicitly.
        return (np.isnan(point[0]), point[0])

    # Sort each series by x so the plotted lines run left to right.
    for per_method in grouped.values():
        for points in per_method.values():
            points.sort(key=_nan_last)

    return grouped


# -------------------------------------------------------------------
# Plotting
# -------------------------------------------------------------------

def _collect_legend_entries(
    axes, methods_order: List[str], method_labels: Dict[str, str]
) -> Tuple[list, List[str]]:
    """Return (handles, labels) from visible axes, de-duplicated and ordered by methods_order."""
    handle_by_label = {}
    for ax in axes:
        if not ax.get_visible():
            continue
        for handle, label in zip(*ax.get_legend_handles_labels()):
            handle_by_label.setdefault(label, handle)  # first panel that drew it wins

    handles: list = []
    labels: List[str] = []
    for method_key in methods_order:
        label = method_labels.get(method_key, method_key)
        if label in handle_by_label:
            labels.append(label)
            handles.append(handle_by_label[label])
    return handles, labels


def plot_four_panel(
    data_by_dataset: Dict[str, Dict[str, List[Tuple[float, float]]]],
    outfile: Path,
) -> None:
    """Plot a 1x4 panel of rel-RMSE vs stratum ratio with a shared legend on top.

    Panels follow a fixed dataset order (letter, optdigits, pendigits, sat). A
    dataset absent from ``data_by_dataset`` leaves its panel hidden. Only the
    method keys listed in ``methods_order`` are drawn; other keys are ignored.
    Both axes are log-scaled, and the y axis is shared across panels.

    Args:
        data_by_dataset: ``{dataset_id: {method_key: [(x, rel_rmse), ...]}}``.
        outfile: Destination image path; missing parent directories are created.

    If none of the four datasets is present, nothing is written and a notice
    is printed.
    """
    methods_order = [
        "J_naive_ips",
        "J_balanced_ips",
        "J_weighted_ips",
        "J_dr_balanced_ips",
        "J_optimal_ips",
        "J_dr_optimal_ips",
    ]
    method_labels = {
        "J_naive_ips": "IPS",
        "J_balanced_ips": "Balanced IPS",
        "J_weighted_ips": "Weighted IPS",
        "J_dr_balanced_ips": "DR-bIPS (Kallus et al.)",
        "J_optimal_ips": "Optimal IPS (This work)",
        "J_dr_optimal_ips": "DR-Optimal IPS (This work)",
    }

    # 6 distinct colors + markers, paired positionally with methods_order
    colors = ["tab:red", "tab:blue", "tab:purple",
              "tab:brown", "tab:green", "tab:orange"]
    markers = ["o", "x", "s", "^", "D", "P"]

    dataset_order = ["letter", "optdigits", "pendigits", "sat"]
    if not any(d in data_by_dataset for d in dataset_order):
        print("No datasets to plot for:", outfile)
        return

    fig, axs = plt.subplots(1, 4, figsize=(16, 4), sharey=True)
    try:
        first_visible = None
        for ax, dataset_id in zip(axs, dataset_order):
            if dataset_id not in data_by_dataset:
                ax.set_visible(False)
                continue
            if first_visible is None:
                first_visible = ax

            series_dict = data_by_dataset[dataset_id]
            for method_key, color, marker in zip(methods_order, colors, markers):
                points = series_dict.get(method_key, [])
                if not points:
                    continue

                xs = [p[0] for p in points]
                ys = [p[1] for p in points]
                ax.plot(
                    xs,
                    ys,
                    label=method_labels.get(method_key, method_key),
                    color=color,
                    marker=marker,
                    markersize=7,
                    linewidth=1.8,
                )

            ax.set_title(title_from_dataset_id(dataset_id))
            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.set_xlabel("Stratum Size Ratio (n1/n2)", fontsize=14)
            ax.grid(True, which="both", linestyle="--", linewidth=0.6)

        # sharey=True hides y tick labels on every column but the first. If the
        # first panel is hidden (dataset missing), the figure would otherwise
        # show no y scale at all, so label the first panel actually shown.
        # first_visible is never None here: the any(...) guard above ensures
        # at least one dataset is present.
        first_visible.tick_params(axis="y", which="both", labelleft=True)
        first_visible.set_ylabel("Relative-RMSE", fontsize=14)

        legend_handles, legend_labels = _collect_legend_entries(
            axs, methods_order, method_labels
        )

        # Lay out subplots leaving room at the top:
        # rect top = 0.85 means the axes occupy [0, 0.85] vertically.
        fig.tight_layout(rect=[0, 0, 1, 0.85])

        # Put the shared legend in the freed space, still inside the figure.
        if legend_handles:
            fig.legend(
                legend_handles,
                legend_labels,
                loc="upper center",
                bbox_to_anchor=(0.5, 0.96),  # <= 1.0 so it won't be cropped
                ncol=len(legend_handles),
                fontsize=11,
                frameon=True,
                fancybox=True,
                handlelength=2.5,
                columnspacing=1.2,
                handletextpad=0.6,
            )

        outfile.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(outfile.as_posix(), dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print(f"Saved figure: {outfile}")


# -------------------------------------------------------------------
# Main
# -------------------------------------------------------------------

def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: plot rel-RMSE figures from a log directory.

    Reads every ``rel_rmse__*.csv`` in ``--log_dir`` and splits the files by
    the ``estimate_pi_b`` flag in their names. It writes ``known.png`` (flag
    false) and ``estimated.png`` (flag true) into that same directory.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behavior.

    Exits with status 1 (message on stderr) when ``--log_dir`` is not an
    existing directory. An existing directory with no matching CSV files only
    prints a notice and returns.
    """
    parser = argparse.ArgumentParser(
        description="Plot rel-RMSE vs stratum size ratio from rel_rmse__*.csv logs."
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        default="log/2_loggers_alpha_0.95",
        help="Directory containing rel_rmse__*.csv files",
    )
    args = parser.parse_args(argv)

    input_dir = Path(args.log_dir)
    if not input_dir.is_dir():
        # A missing or mistyped path is a user error: report it with a
        # non-zero exit status so calling scripts notice the failure.
        # is_dir() also rejects a regular file, which exists() would accept.
        sys.exit(f"Directory not found: {input_dir}")

    csv_files = sorted(input_dir.glob("rel_rmse__*.csv"))
    if not csv_files:
        print(f"No CSV files found in {input_dir}")
        return

    data_known = aggregate_series_by_dataset(csv_files, estimate_flag=False)
    data_estimated = aggregate_series_by_dataset(csv_files, estimate_flag=True)

    plot_four_panel(data_known, input_dir / "known.png")
    plot_four_panel(data_estimated, input_dir / "estimated.png")


# Run the CLI only when executed as a script; importing this module does not plot.
if __name__ == "__main__":
    main()
